On Notes

Author: test author

Published: January 1, 2020



!pip install clu
Collecting clu

  Downloading clu-0.0.12-py3-none-any.whl.metadata (1.9 kB)

Collecting ml-collections (from clu)

  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)

Installing collected packages: ml-collections, clu

Successfully installed clu-0.0.12 ml-collections-1.1.0
import jax.numpy as jnp
import jax
import flax.linen as nn
import flax
import tensorflow_datasets as tfds
from functools import partial
import numpy as np
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct  # Flax dataclasses
import optax
from clu import metrics
from typing import Sequence, Any
import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf  # TensorFlow operations

tf.random.set_seed(0)  # set random seed for reproducibility

num_epochs = 10
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize train set
test_ds = test_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize test set

# repeat the data num_epochs times and shuffle it with a 1024-element buffer
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
# group into batches of batch_size, drop the incomplete last batch, and prefetch the next batch to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# shuffle and batch the test set the same way
test_ds = test_ds.shuffle(1024)
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
WARNING:absl:Variant folder /root/tensorflow_datasets/mnist/3.0.1 has no dataset_info.json
Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
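
The map/batch steps above are plain tf.data plumbing. As a sanity check, here is a numpy-only sketch of the same normalization and batching, with synthetic uint8 images standing in for MNIST (the names here are illustrative, not part of the pipeline above):

```python
import numpy as np

# Synthetic stand-in for MNIST: 64 grayscale 28x28 uint8 images.
images = np.random.randint(0, 256, size=(64, 28, 28, 1), dtype=np.uint8)

# Same normalization as the tf.data map: cast to float32 and scale to [0, 1].
normalized = images.astype(np.float32) / 255

# Same batching: groups of batch_size=32, incomplete remainder dropped.
batch_size = 32
n_full = (len(normalized) // batch_size) * batch_size
batches = [normalized[i:i + batch_size] for i in range(0, n_full, batch_size)]

print(len(batches), batches[0].shape)  # 2 (32, 28, 28, 1)
```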
class down_block(nn.Module):

    @nn.compact
    def __call__(self,
                 inputs,
                 n_filter=32,
                 max_pooling=True,
                 training=True):
        conv = nn.Conv(
            n_filter,
            3,  # filter size
            padding='SAME',
            kernel_init=nn.initializers.he_normal())(inputs)
        skip_connection = conv
        conv = nn.BatchNorm(use_running_average=not training)(conv)
        conv = nn.relu(conv)
        conv = nn.Conv(
            n_filter,
            3,  # filter size
            padding='SAME',
            kernel_init=nn.initializers.he_normal())(conv)
        conv = nn.relu(conv)

        if max_pooling:
            # NB: flax's max_pool defaults to strides of (1, 1), so with
            # 'SAME' padding this keeps the spatial resolution unchanged
            next_layer = nn.max_pool(conv, window_shape=(2, 2), padding='SAME')
        else:
            next_layer = conv
        return next_layer, skip_connection


class up_block(nn.Module):

    @nn.compact
    def __call__(self, inputs, skip_connection, filters, training=True):
        if skip_connection is None:
            x = inputs
        else:
            x = jnp.concatenate([inputs, skip_connection], axis=3)
        x = nn.Conv(filters, 3, padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Conv(filters, 3, padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        return x


class UNet(nn.Module):

    @nn.compact
    def __call__(self, inputs):
        filters = [64, 128, 256]  # avoid shadowing the built-ins `input`/`filter`
        # encode
        x, skip1 = down_block()(inputs, filters[0])
        x, skip2 = down_block()(x, filters[1])
        x, _ = down_block()(x, filters[2], max_pooling=False)
        # decode
        x = up_block()(x, skip2, filters[1])
        x = up_block()(x, skip1, filters[0])
        x = up_block()(x, None, 1)
        return x



unet = UNet()
# print(
#     unet.tabulate(jax.random.key(0),
#                   jnp.ones((1, 28, 28, 1)),
#                   compute_flops=True,
#                   compute_vjp_flops=True))
@struct.dataclass
class Metrics(metrics.Collection):
    loss: metrics.Average.from_output('loss')


class TrainState(train_state.TrainState):
    metrics: Metrics
    batch_stats: Any


def create_train_state(module, rng, learning_rate, momentum):
    """Creates an initial `TrainState`."""
    variables = module.init(rng, jnp.ones([1, 28, 28, 1]))
    params = variables[
        'params']  # initialize parameters by passing a template image
    batch_stats = variables[
        'batch_stats']  # initialize batch_stats by passing a template image
    tx = optax.sgd(learning_rate, momentum)
    return TrainState.create(apply_fn=module.apply,
                             params=params,
                             batch_stats=batch_stats,
                             tx=tx,
                             metrics=Metrics.empty())


@jax.jit
def train_step(state, batch):
    """Train for a single step."""

    def loss_fn(params):
        predicted, updates = state.apply_fn(
            {
                'params': params,
                'batch_stats': state.batch_stats
            },
            batch['image'],
            mutable=['batch_stats'],
            rngs={'dropout': jax.random.key(1)})
        loss = optax.losses.l2_loss(predictions=predicted,
                                    targets=batch['image']).mean()
        return loss, (predicted, updates)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (predicted, updates)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats'])
    return state
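
`train_step` leans on `jax.value_and_grad(..., has_aux=True)`, which lets `loss_fn` return extra values (here the predictions and the mutated `batch_stats`) alongside the scalar loss. A self-contained sketch of that pattern on a toy linear model:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    pred = w * x
    loss = jnp.mean((pred - y) ** 2)
    return loss, pred  # scalar loss first, auxiliary output second

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, pred), grad = grad_fn(2.0, jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0]))
print(loss, grad)  # both 0.0: w=2 fits the data exactly, so the gradient vanishes
```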


@jax.jit
def compute_metrics(*, state, batch):
    predicted, updates = state.apply_fn(
        {
            'params': state.params,
            'batch_stats': state.batch_stats
        },
        batch['image'],
        mutable=['batch_stats'],
        rngs={'dropout': jax.random.key(1)})
    loss = optax.losses.l2_loss(predictions=predicted,
                                targets=batch['image']).mean()
    # only `loss` feeds the Metrics collection defined above
    metric_updates = state.metrics.single_from_model_output(loss=loss)
    metrics = state.metrics.merge(metric_updates)
    state = state.replace(metrics=metrics)
    return state
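
`compute_metrics` folds each batch loss into the clu `Metrics` collection; conceptually, `metrics.Average` just keeps a running (total, count) pair that `merge` adds up. A pure-Python sketch of the idea (not clu's actual implementation):

```python
class RunningAverage:
    """Toy stand-in for clu's metrics.Average."""
    def __init__(self, total=0.0, count=0):
        self.total, self.count = total, count

    def merge(self, other):
        # merging two partial aggregates just sums their sufficient statistics
        return RunningAverage(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count

avg = RunningAverage()
for batch_loss in [0.5, 0.25]:  # per-batch losses within an epoch
    avg = avg.merge(RunningAverage(batch_loss, 1))
print(avg.compute())  # 0.375 — the mean loss over the epoch so far
```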


num_epochs = 10
batch_size = 32

# train_ds, test_ds = get_datasets(num_epochs, batch_size)

# tf.random.set_seed(0)
init_rng = jax.random.key(0)
learning_rate = 0.01
momentum = 0.9
state = create_train_state(unet, init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

# since train_ds is repeated num_epochs times above, divide by num_epochs
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs
num_steps_per_epoch = 2  # deliberately overridden to keep this demo run short
print(num_steps_per_epoch)

metrics_history = {
    'train_loss': [],
    'test_loss': [],
}


test_summary_writer = tf.summary.create_file_writer('test/logdir')

for step, batch in enumerate(train_ds.as_numpy_iterator()):
    if step > 20:
        break

    # Run an optimization step over a training batch and compute batch metrics
    state = train_step(
        state, batch
    )  # get updated train state (which contains the updated parameters)
    state = compute_metrics(state=state,
                            batch=batch)  # aggregate batch metrics

    if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
        for metric, value in state.metrics.compute().items():  # compute metrics
            metrics_history[f'train_{metric}'].append(value)  # record metrics
            with test_summary_writer.as_default():
                tf.summary.scalar('train/loss', value, step=step)

        state = state.replace(
            metrics=state.metrics.empty())  # reset train metrics for the next epoch

        # Compute metrics on the test set after each training epoch
        test_state = state
        for test_batch in test_ds.as_numpy_iterator():
            test_state = compute_metrics(state=test_state, batch=test_batch)

        for metric, value in test_state.metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)
            with test_summary_writer.as_default():
                tf.summary.scalar('test/loss', value, step=step)

        print(f"train epoch: {(step + 1) // num_steps_per_epoch}, "
              f"loss: {metrics_history['train_loss'][-1]}, ")
        print(f"test epoch: {(step + 1) // num_steps_per_epoch}, "
              f"loss: {metrics_history['test_loss'][-1]}, ")
        # writer.add_image('images', grid, 0)
        # writer.add_graph(model, images)
2
train epoch: 1, loss: 0.23241570591926575, 
test epoch: 1, loss: 0.19601216912269592, 
train epoch: 2, loss: 0.1533064991235733, 
test epoch: 2, loss: 0.15013296902179718, 
train epoch: 3, loss: 0.12592682242393494, 
test epoch: 3, loss: 0.11877958476543427, 
train epoch: 4, loss: 0.0943879559636116, 
test epoch: 4, loss: 0.08934623003005981, 
train epoch: 5, loss: 0.06853969395160675, 
test epoch: 5, loss: 0.06333110481500626, 
train epoch: 6, loss: 0.04726167768239975, 
test epoch: 6, loss: 0.04230469465255737, 
train epoch: 7, loss: 0.02610059641301632, 
test epoch: 7, loss: 0.026588963344693184, 
train epoch: 8, loss: 0.01785922423005104, 
test epoch: 8, loss: 0.015624734573066235, 
train epoch: 9, loss: 0.01218641921877861, 
test epoch: 9, loss: 0.008579996414482594, 
train epoch: 10, loss: 0.006207008380442858, 
test epoch: 10, loss: 0.004718703217804432, 
# %load_ext tensorboard
%tensorboard --logdir 'test/logdir'
np.random.rand(1, 28, 28)  # np.random.rand takes dimensions as separate args, not a tuple
conv = nn.ConvTranspose(features=2, kernel_size=(3, 3))
x = jnp.ones((1, 28, 28, 1))  # template input; `x` was previously undefined
conv.init_with_output(jax.random.key(0), x)